#!/usr/bin/env python
# -*- coding: utf-8 -*-
from tqsdk import TqApi, TqAuth, TargetPosTask,TqKq,TqAccount
import pandas as pd
import numpy as np
from datetime import datetime
import yaml
from typing import Union, List, Any, Optional
import time
import os
import pickle
class Strategy(object):
    """Single-symbol futures strategy base built on a tqsdk ``TqApi``.

    Settings come from the ``Params`` mapping of a YAML config file: symbols,
    account credentials, sizing, output paths and the persisted position state
    (``MarketPosition``, ``pos``, entry/exit price and date). Every fill is
    appended to ``<path_out>/<stra_name>/<stra_name>_trade_log.csv`` and the
    position state is written back to the YAML file.
    """

    # Column layout of the per-strategy trade-log CSV.
    _TRADE_LOG_COLUMNS = ['stra_name', 'symbol', 'datetime', 'buyorsell', 'price', 'lots']
    # Exchanges that distinguish closing today's position (CLOSETODAY) from older ones (CLOSE).
    _CLOSETODAY_EXCHANGES = ('SHFE', 'INE')
    # Timestamp format used for trade-log rows and for entrydate/exitdate.
    _DATE_FORMAT = "%Y%m%d %H:%M:%S"

    def __init__(self, configfile):
        """Load *configfile*, connect to tqsdk and subscribe to K-lines.

        :param configfile: path of the YAML config; it must contain a ``Params``
            mapping and is rewritten after every fill.
        """
        self.configfile = configfile
        dict_cfg = self.load_config(self.configfile)
        self.params = dict_cfg['Params']
        self.get_params()
        self.get_data()
        print(self.stra_name + ' __init__  {}的总仓位:{} pos:{} MarketPosition:{}  Quote_symbol:{} Trade_symbol:{} acc_name:{} is_testing:{}'.format(self.Trade_symbol.split('.')[1], self.position.pos, self.pos, self.MarketPosition, self.Quote_symbol, self.Trade_symbol, self.acc_name, self.is_testing))

    def get_params(self):
        """Copy settings from ``self.params`` onto the instance and open the API.

        ``is_testing == 1`` connects to the TqKq simulated account and
        ``is_testing == 0`` to the real account ``acc_name``/``acc_number``.

        :raises ValueError: if ``is_testing`` is neither 0 nor 1.
        """
        self.Quote_symbol = self.params['Quote_symbol']  # e.g. 'KQ.m@' + symbol; KQ.m@CFFEX.IF is the CFFEX IF main continuous contract
        self.Trade_symbol = self.params['Trade_symbol']
        self.ID = self.params['ID']
        self.stra_name = self.params['stra_name']
        self.port = self.params['port']

        self.acc_number = self.params['acc_number']
        self.acc_name = self.params['acc_name']
        self.pwd = self.params['pwd']
        self.is_testing = self.params['is_testing']

        # YAML loads a bare port number as int; the web_gui address must be a string.
        web_gui = ':' + str(self.port)
        if self.is_testing == 1:
            self.api = TqApi(TqKq(), auth=',', web_gui=web_gui, disable_print=True)
        elif self.is_testing == 0:
            self.api = TqApi(TqAccount(self.acc_name, self.acc_number, self.pwd), auth=',', web_gui=web_gui, disable_print=True)
        else:
            raise ValueError('is_testing must be 0 or 1, got {!r}'.format(self.is_testing))
        self.position = self.api.get_position(self.Trade_symbol)
        self.account = self.api.get_account()
        self.target_pos = TargetPosTask(self.api, self.Trade_symbol)
        self.lots = self.params['lots']
        self.pos = self.params['pos']
        self.Interval = self.params['Interval']
        self.path_in = self.params['path_in']
        self.path_out = self.params['path_out']
        self.MarketPosition = self.params['MarketPosition']
        self.entryprice = self.params['entryprice']
        self.entrydate = self.params['entrydate']
        self.exitprice = self.params['exitprice']
        self.exitdate = self.params['exitdate']

    def closeapi(self):
        """Close the underlying ``TqApi`` connection."""
        self.api.close()

    def __getstate__(self):
        """Return the instance ``__dict__`` as pickle state.

        NOTE(review): this includes the live ``api`` object, which is probably
        not picklable - confirm before relying on pickling.
        """
        return self.__dict__

    def get_data(self):
        """Subscribe to the K-line series of ``Quote_symbol``.

        ``Interval`` is passed as tqsdk's ``duration_seconds`` argument;
        8964 is presumably tqsdk's maximum series length - TODO confirm.
        """
        self.klines = self.api.get_kline_serial(self.Quote_symbol, self.Interval, data_length=8964)

    def write_trade_log(self, item):
        """Append one fill to this strategy's trade-log CSV.

        :param item: ``[stra_name, symbol, datetime, buyorsell, price, lots]``.

        The file is ``<path_out>/<stra_name>/<stra_name>_trade_log.csv``; the
        directory (including missing parents) and file are created on first use.
        The whole file is rewritten with a fresh 0..n-1 index column.
        """
        log_dir = os.path.join(self.path_out, self.stra_name)
        log_file = os.path.join(log_dir, self.stra_name + '_trade_log.csv')
        os.makedirs(log_dir, exist_ok=True)

        rows = []
        # A zero-byte file cannot be parsed by read_csv, so treat it as empty.
        if os.path.isfile(log_file) and os.path.getsize(log_file) > 0:
            rows = pd.read_csv(log_file, index_col=0).values.tolist()
        rows.append(list(item))
        df_trade_log = pd.DataFrame(rows, columns=self._TRADE_LOG_COLUMNS)
        df_trade_log.to_csv(log_file)
        print(self.stra_name + ' ' + self.Trade_symbol + ' one log write to trade_log ')

    def _print_state(self, label):
        """Print the broker position and the strategy ``pos`` tagged with *label*."""
        info_dict = {'position': self.position.pos, 'pos': self.pos,
                     'acc_name': self.acc_name, 'stra_name': self.stra_name}
        self.PRINT(label, info_dict)

    def _opened_today(self):
        """Return True if ``entrydate`` falls on today's local date.

        Only the digits of ``entrydate`` are compared, so both the
        ``%Y%m%d %H:%M:%S`` form written by this class and a dashed
        ``%Y-%m-%d ...`` form are recognised; a missing date is never today.
        """
        digits = ''.join(ch for ch in str(self.entrydate or '') if ch.isdigit())
        return digits[:8] == time.strftime("%Y%m%d", time.localtime())

    def _close_offset(self):
        """Return the ``offset`` to use when closing the current position.

        SHFE and INE require CLOSETODAY for a position opened today; every other
        case uses CLOSE.
        """
        exchange = self.Trade_symbol.split('.')[0].upper()
        if exchange in self._CLOSETODAY_EXCHANGES and self._opened_today():
            return "CLOSETODAY"
        return "CLOSE"

    def _wait_order(self, order):
        """Block until *order* is FINISHED; return True only if it fully traded.

        A FINISHED order with ``volume_left > 0`` was presumably cancelled or
        rejected. It is reported and not recorded as a fill. A partial fill
        leaves the strategy state unchanged and must be reconciled manually.
        """
        while True:
            self.api.wait_update()
            if order.status == "FINISHED":
                break
        if order.volume_left == 0:
            return True
        print('{} {} order not fully filled: traded {} of {} lots, msg: {}'.format(
            self.stra_name, self.Trade_symbol, order.volume_orign - order.volume_left,
            order.volume_orign, getattr(order, 'last_msg', '')))
        return False

    def _record_fill(self, order, action, market_position, price_key, date_key):
        """Apply a filled *order* to the strategy state and persist it.

        :param action: trade-log label ('buy', 'sell', 'sellshort', 'buytocover').
        :param market_position: new ``MarketPosition`` (1, -1 or 0); ``pos``
            becomes ``market_position * lots``.
        :param price_key: 'entryprice' or 'exitprice' (attribute and config key).
        :param date_key: 'entrydate' or 'exitdate' (attribute and config key).
        """
        date_str = time.strftime(self._DATE_FORMAT, time.localtime())
        # Keep a plain float so yaml.Dumper writes a simple scalar to the config.
        price = float(order.trade_price)
        self.MarketPosition = market_position
        self.pos = market_position * self.lots
        self.write_trade_log([self.stra_name, self.Trade_symbol, date_str, action, price, self.lots])
        setattr(self, price_key, price)
        setattr(self, date_key, date_str)
        params = self.load_config(self.configfile)
        params['Params'][price_key] = price
        params['Params'][date_key] = date_str
        self.update_cfg_yml(params)

    def buy(self):
        """Open a long of ``lots`` at the last price, covering any short first.

        If the short is not fully covered, no long is opened.
        """
        self._print_state('before buy')
        if self.MarketPosition < 0:
            self.buytocover()
            if self.MarketPosition < 0:
                # Opening now would leave both a short and a long open.
                self._print_state('buy aborted')
                return

        quote = self.api.get_quote(self.Trade_symbol)
        order = self.api.insert_order(symbol=self.Trade_symbol, direction="BUY", offset="OPEN",
                                      volume=self.lots, limit_price=quote.last_price)
        if self._wait_order(order):
            self._record_fill(order, 'buy', 1, 'entryprice', 'entrydate')
        self._print_state('after buy')

    def sell(self):
        """Close the long position of ``lots`` at the last price; no-op unless long."""
        if self.MarketPosition <= 0:
            return
        self._print_state('before sell')
        quote = self.api.get_quote(self.Trade_symbol)
        order = self.api.insert_order(symbol=self.Trade_symbol, direction="SELL",
                                      offset=self._close_offset(), volume=self.lots,
                                      limit_price=quote.last_price)
        if self._wait_order(order):
            self._record_fill(order, 'sell', 0, 'exitprice', 'exitdate')
        self._print_state('after sell')

    def sellshort(self):
        """Open a short of ``lots`` at the last price, closing any long first.

        If the long is not fully closed, no short is opened.
        """
        self._print_state('before sellshort')
        if self.MarketPosition > 0:
            self.sell()
            if self.MarketPosition > 0:
                # Opening now would leave both a long and a short open.
                self._print_state('sellshort aborted')
                return

        quote = self.api.get_quote(self.Trade_symbol)
        order = self.api.insert_order(symbol=self.Trade_symbol, direction="SELL", offset="OPEN",
                                      volume=self.lots, limit_price=quote.last_price)
        if self._wait_order(order):
            self._record_fill(order, 'sellshort', -1, 'entryprice', 'entrydate')
        self._print_state('after sellshort')

    def buytocover(self):
        """Cover the short position of ``lots`` at the last price; no-op unless short."""
        if self.MarketPosition >= 0:
            return
        self._print_state('before buytocover')
        quote = self.api.get_quote(self.Trade_symbol)
        order = self.api.insert_order(symbol=self.Trade_symbol, direction="BUY",
                                      offset=self._close_offset(), volume=self.lots,
                                      limit_price=quote.last_price)
        if self._wait_order(order):
            self._record_fill(order, 'buytocover', 0, 'exitprice', 'exitdate')
        self._print_state('after buytocover')

    def stra(self):
        """Strategy main loop; a no-op here, presumably overridden by subclasses.

        Example loop from the original template::

            while True:
                self.api.wait_update()
                if self.api.is_changing(self.klines.iloc[-1], "datetime"):
                    # datetime: nanoseconds since the unix epoch (1970-01-01 00:00:00 GMT)
                    print("新K线", datetime.fromtimestamp(self.klines.iloc[-1]["datetime"] / 1e9))
                # check whether the close of the last bar changed
                if self.api.is_changing(self.klines.iloc[-1], "close"):
                    # klines.close is the close-price series
                    print("K线变化", datetime.fromtimestamp(self.klines.iloc[-1]["datetime"] / 1e9), self.klines.close.iloc[-1])
            self.api.close()
        """

    def load_config(self, path_filename):
        """Parse the YAML file *path_filename* (UTF-8) and return its contents.

        Uses ``yaml.FullLoader``; only load trusted, local config files.
        """
        with open(path_filename, encoding='utf-8') as stra_cfg_json_file:
            dict_cfg = yaml.load(stra_cfg_json_file, Loader=yaml.FullLoader)
        return dict_cfg

    def dump_config(self, path_filename, params):
        """Write *params* to *path_filename* as UTF-8 block-style YAML.

        The document is rendered to a string first and swapped in through a
        temporary file. This way, a failure while serializing or writing cannot
        leave a truncated config behind.
        """
        text = yaml.dump(params, Dumper=yaml.Dumper, default_flow_style=False, allow_unicode=True)
        tmp_path = str(path_filename) + '.tmp'
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(text)
        os.replace(tmp_path, path_filename)

    def update_cfg_yml(self, params):
        """Store the current ``MarketPosition``/``pos`` into *params* and save it.

        :param params: the full config mapping (with a ``Params`` key), normally
            freshly loaded via ``load_config``.
        """
        params['Params']['MarketPosition'] = self.MarketPosition
        params['Params']['pos'] = self.pos
        self.dump_config(self.configfile, params)

    def PRINT(self, beforeorafter, info_dict):
        """Print a timestamped position line.

        :param beforeorafter: label such as 'before buy' / 'after sell'.
        :param info_dict: mapping with 'position', 'stra_name', 'pos', 'acc_name'.
        """
        position = info_dict['position']
        stra_name = info_dict['stra_name']
        pos = info_dict['pos']
        acc_name = info_dict['acc_name']

        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' ' + stra_name + ' ' + beforeorafter
              + ' acc_name: {}  总仓位: {}   pos:{}'.format(acc_name, position, pos))
